{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MusicGen-Style\n",
    "Welcome to MusicGen-Style's demo jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen-Style in different settings.\n",
    "\n",
    "First, we start by initializing MusicGen-Style."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.models import MusicGen\n",
    "from audiocraft.models import MultiBandDiffusion\n",
    "\n",
    "USE_DIFFUSION_DECODER = False\n",
    "\n",
    "model = MusicGen.get_pretrained('facebook/musicgen-style')\n",
    "if USE_DIFFUSION_DECODER:\n",
    "    mbd = MultiBandDiffusion.get_mbd_musicgen()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
    "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n",
    "* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n",
    "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n",
    "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n",
    "* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.\n",
    "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n",
    "* `cfg_coef_beta` (float, optional): If not None, we use double CFG. cfg_coef_beta is the parameter that pushes the text. Defaults to None, user should start at 5.\n",
    "    If the generated music adheres to much to the text, the user should reduce this parameter. If the music adheres too much to the style conditioning, \n",
    "    the user should increase it\n",
    "\n",
    "When left unchanged, MusicGen will revert to its default parameters.\n",
    "\n",
    "These are the conditioner parameters for the style conditioner:\n",
    "* `eval_q` (int): integer between 1 and 6 included that tells how many quantizers are used in the RVQ bottleneck\n",
    "    of the style conditioner. The higher eval_q is, the more style information passes through the model.\n",
    "* `excerpt_length` (float): float between 1.5 and 4.5 that indicates which length is taken from the audio \n",
    "    conditioning to extract style. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.set_generation_params(\n",
    "    use_sampling=True,\n",
    "    top_k=250,\n",
    "    duration=30\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model can perform text-to-music, style-to-music and text-and-style-to-music.\n",
    "* Text-to-music can be done using `model.generate`, or `model.generate_with_chroma` with the wav condition being None. \n",
    "* Style-to-music and Text-and-Style-to-music can be done using `model.generate_with_chroma`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text-to-Music"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "model.set_generation_params(\n",
    "    duration=8, # generate 8 seconds, can go up to 30\n",
    "    use_sampling=True, \n",
    "    top_k=250,\n",
    "    cfg_coef=3., # Classifier Free Guidance coefficient \n",
    "    cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning\n",
    ")\n",
    "\n",
    "output = model.generate(\n",
    "    descriptions=[\n",
    "        '80s pop track with bassy drums and synth',\n",
    "        '90s rock song with loud guitars and heavy drums',\n",
    "        'Progressive rock drum and bass solo',\n",
    "        'Punk Rock song with loud drum and power guitar',\n",
    "        'Bluesy guitar instrumental with soulful licks and a driving rhythm section',\n",
    "        'Jazz Funk song with slap bass and powerful saxophone',\n",
    "        'drum and bass beat with intense percussions'\n",
    "    ],\n",
    "    progress=True, return_tokens=True\n",
    ")\n",
    "display_audio(output[0], sample_rate=32000)\n",
    "if USE_DIFFUSION_DECODER:\n",
    "    out_diffusion = mbd.tokens_to_wav(output[1])\n",
    "    display_audio(out_diffusion, sample_rate=32000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Style-to-Music\n",
    "For Style-to-Music, we don't need double CFG. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchaudio\n",
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "model.set_generation_params(\n",
    "    duration=8, # generate 8 seconds, can go up to 30\n",
    "    use_sampling=True, \n",
    "    top_k=250,\n",
    "    cfg_coef=3., # Classifier Free Guidance coefficient \n",
    "    cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning\n",
    ")\n",
    "\n",
    "model.set_style_conditioner_params(\n",
    "    eval_q=1, # integer between 1 and 6\n",
    "              # eval_q is the level of quantization that passes\n",
    "              # through the conditioner. When low, the models adheres less to the \n",
    "              # audio conditioning\n",
    "    excerpt_length=3., # the length in seconds that is taken by the model in the provided excerpt\n",
    "    )\n",
    "\n",
    "melody_waveform, sr = torchaudio.load(\"../assets/electronic.mp3\")\n",
    "melody_waveform = melody_waveform.unsqueeze(0).repeat(2, 1, 1)\n",
    "output = model.generate_with_chroma(\n",
    "    descriptions=[None, None], \n",
    "    melody_wavs=melody_waveform,\n",
    "    melody_sample_rate=sr,\n",
    "    progress=True, return_tokens=True\n",
    ")\n",
    "display_audio(output[0], sample_rate=32000)\n",
    "if USE_DIFFUSION_DECODER:\n",
    "    out_diffusion = mbd.tokens_to_wav(output[1])\n",
    "    display_audio(out_diffusion, sample_rate=32000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text-and-Style-to-Music\n",
    "For Text-and-Style-to-Music, if we use simple classifier free guidance, the models tends to ignore the text conditioning. We then, introduce double classifier free guidance \n",
    "$$l_{\\text{double CFG}} = l_{\\emptyset} + \\alpha [l_{style} + \\beta(l_{text, style} - l_{style}) - l_{\\emptyset}]$$\n",
    "\n",
    "For $\\beta=1$ we retrieve classic CFG but if $\\beta > 1$ we boost the text condition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchaudio\n",
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "model.set_generation_params(\n",
    "    duration=8, # generate 8 seconds, can go up to 30\n",
    "    use_sampling=True, \n",
    "    top_k=250,\n",
    "    cfg_coef=3., # Classifier Free Guidance coefficient \n",
    "    cfg_coef_beta=5., # double CFG is necessary for text-and-style conditioning\n",
    "                   # Beta in the double CFG formula. between 1 and 9. When set to 1 \n",
    "                   # it is equivalent to normal CFG. \n",
    ")\n",
    "\n",
    "model.set_style_conditioner_params(\n",
    "    eval_q=1, # integer between 1 and 6\n",
    "              # eval_q is the level of quantization that passes\n",
    "              # through the conditioner. When low, the models adheres less to the \n",
    "              # audio conditioning\n",
    "    excerpt_length=3., # the length in seconds that is taken by the model in the provided excerpt\n",
    "    )\n",
    "\n",
    "melody_waveform, sr = torchaudio.load(\"../assets/electronic.mp3\")\n",
    "melody_waveform = melody_waveform.unsqueeze(0).repeat(3, 1, 1)\n",
    "\n",
    "descriptions = [\"8-bit old video game music\", \"Chill lofi remix\", \"80s New wave with synthesizer\"]\n",
    "\n",
    "output = model.generate_with_chroma(\n",
    "    descriptions=descriptions,\n",
    "    melody_wavs=melody_waveform,\n",
    "    melody_sample_rate=sr,\n",
    "    progress=True, return_tokens=True\n",
    ")\n",
    "display_audio(output[0], sample_rate=32000)\n",
    "if USE_DIFFUSION_DECODER:\n",
    "    out_diffusion = mbd.tokens_to_wav(output[1])\n",
    "    display_audio(out_diffusion, sample_rate=32000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "vscode": {
   "interpreter": {
    "hash": "b02c911f9b3627d505ea4a19966a915ef21f28afb50dbf6b2115072d27c69103"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
